from __future__ import division, print_function, absolute_import

import numpy as np
import pickle
from CDTools.tools.cmath import *
import torch as t
from RIP_tools import *

#
# This section runs the numerical experiment designed
# to test the robustness of RIP to Poisson noise.
#

# Set GPU to False if you don't have a GPU
GPU = True

# Passed to generate_blr_probe as its third argument; the simulated
# field size and the reconstruction resolution below are derived from it.
probe_maxk = 128

# Side length of the (square) array the probe is simulated on
field_size = 2 * probe_maxk
# Focal spot size of the BLR probe (four tenths of the field size,
# rounded down to an integer)
probe_fov_radius = (4 * field_size) // 10

# Fraction of the image trimmed from each edge before computing the
# error, so that the comparison covers the central half of the image.
error_check_padding = 0.25


# Generate an ideal BLR probe with the parameters chosen above
nearfield = generate_blr_probe([field_size] * 2,
                               probe_fov_radius,
                               probe_maxk)


# Reconstruction size for R = 0.4; halving inside int() and doubling
# afterwards keeps the value even.
resolution = 2 * int(probe_maxk * 0.8 / 2)

# Per-photon-count lists of the final loss and error of every
# reconstruction
losses, errors = [], []

# Sweep over the total number of photons in the image
# (60 values, logarithmically spaced from 10**3.5 to 10**8.5)
photon_counts = np.logspace(3.5, 8.5, num=60)
# Area, in pixels of the low-resolution reconstructed object, of a
# probe which covers 80% of the stage
num_pix_in_lr_obj = np.pi * (0.4 * resolution) ** 2
photons_per_pixel = photon_counts / num_pix_in_lr_obj

# Width in pixels of the border excluded from the error comparison.
# It depends only on fixed settings, so it is computed once up front.
margin = int(resolution * error_check_padding)
central_region = (slice(margin, -margin), slice(margin, -margin))

for photon_count in photon_counts:
    print('Simulating photon count',photon_count)
    trial_losses, trial_errors = [], []

    for trial in range(50):
        print('Running attempt',trial)

        # Draw a random complex test image (real part first, then the
        # imaginary part) and convert it to a float32 torch tensor.
        re_part = np.random.randn(resolution, resolution)
        im_part = np.random.randn(resolution, resolution)
        truth = complex_to_torch(re_part + 1j * im_part).to(dtype=t.float32)

        # Upsample the probe in real space so that the probe-object
        # interaction is not aliased.
        upsampled_probe = expand_probe(nearfield, truth)

        # Forward simulation: probe-object interaction, propagation,
        # then measurement on an array the shape of the probe.
        expected = measure(propagate(interact(truth, upsampled_probe)),
                           nearfield.shape)

        # Rescale so that the pattern sums to the requested photon count
        expected /= t.sum(expected)
        expected *= photon_count

        # Replace the expected counts with Poisson-distributed samples
        noisy = np.random.poisson(expected.detach().cpu().numpy())
        noisy = t.Tensor(noisy).to(dtype=t.float32)

        # Attempt a reconstruction with a randomized initialization.
        # NOTE(review): 0.4 presumably matches the R = 0.4 used to pick
        # `resolution`, and 1000 is presumably an iteration count --
        # confirm against RIP_tools.reconstruct.
        recovered, loss_history = reconstruct(noisy, nearfield, resolution,
                                              0.4, 1000, GPU=GPU,
                                              schedule=True)

        # Overall RMS error within the central region; the other two
        # return values are not needed here.
        rms_error, _, _ = compare_images(torch_to_complex(recovered),
                                         torch_to_complex(truth),
                                         central_region)

        # Keep the diffraction loss at the end of the run and the RMS
        # image error
        trial_losses.append(loss_history[-1])
        trial_errors.append(rms_error)

    errors.append(trial_errors)
    losses.append(trial_losses)
    
    

# Gather the sweep settings, the probe and the per-attempt losses and
# errors into a single dictionary for later analysis.
results = {
    'photon counts': photon_counts,
    'probe': nearfield,
    'losses': losses,
    'errors': errors,
    'photons per pixel': photons_per_pixel,
}

# Make sure the output directory exists before writing. Without this, a
# missing 'data' folder would make open() fail and discard the results
# of the whole sweep at the very last step.
import os
output_path = 'data/poisson trial.pickle'
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.isdir(output_dir):
    os.makedirs(output_dir)

with open(output_path, 'wb') as f:
    pickle.dump(results, f)


